# -*- coding:utf-8 -*-　　
# @name: redis_handler
# @auth: rainy-autumn@outlook.com
# @version:
import asyncio
import json
from loguru import logger
import redis.asyncio as redis

from core.db import *
from core.util import *
import socket
from motor.motor_asyncio import AsyncIOMotorCursor


async def get_redis_pool():
    """Async generator that yields one Redis client and closes it afterwards.

    The client is built from ``REDIS_PASSWORD``, ``REDIS_IP`` and
    ``REDIS_PORT`` with UTF-8 decoded responses and TCP keep-alive enabled.
    After the consumer is done, the client is closed and its connection
    pool is disconnected.
    """
    # darwin (macOS) uses TCP_KEEPALIVE where every other platform uses
    # TCP_KEEPIDLE; the remaining keep-alive options are the same.
    if sys.platform == 'darwin':
        idle_option = socket.TCP_KEEPALIVE
    else:
        idle_option = socket.TCP_KEEPIDLE
    keep_alive_config = {
        'socket_keepalive': True,
        'socket_keepalive_options': {
            idle_option: 60,
            socket.TCP_KEEPCNT: 10,
            socket.TCP_KEEPINTVL: 10,
        },
    }
    redis_url = f"redis://:{REDIS_PASSWORD}@{REDIS_IP}:{REDIS_PORT}"
    redis_con = await redis.from_url(
        redis_url,
        encoding="utf-8",
        decode_responses=True,
        **keep_alive_config,
    )
    try:
        yield redis_con
    finally:
        await redis_con.close()
        await redis_con.connection_pool.disconnect()


async def refresh_config(name, t, content=None):
    """Queue a config-refresh message for one node or for every eligible node.

    The message is a JSON object with ``name`` and ``type`` (plus ``content``
    when it is not ``None``) and is appended to the Redis list
    ``refresh_config:<node>``.

    :param name: target node name, or ``"all"`` to address every ``node:*``
        key whose ``state`` hash field is not ``'3'``.
    :param t: value stored under the message's ``type`` key.
    :param content: optional value stored under the ``content`` key.
    """
    payload = {"name": name, "type": t}
    if content is not None:
        payload['content'] = content

    async for redis_client in get_redis_pool():
        if name == "all":
            recipients = []
            for node_key in await redis_client.keys("node:*"):
                node_name = node_key.split(":")[1]
                node_hash = await redis_client.hgetall(node_key)
                # Nodes in state '3' do not receive the refresh message.
                if node_hash.get('state') != '3':
                    recipients.append(node_name)
        else:
            recipients = [name]

        for recipient in recipients:
            await redis_client.rpush(f"refresh_config:{recipient}", json.dumps(payload))


async def subscribe_log_channel():
    """Consume the ``logs`` pub/sub channel forever and persist each log line.

    Every message is expected to be a JSON object with ``name`` and ``log``
    keys. For each well-formed message:

    * if ``name`` is in ``GET_LOG_NAME`` the line is appended to
      ``LOG_INFO[name]``;
    * if the line contains ``"Register Success"``, ``check_node_task`` is run
      for that node;
    * the line is appended to the Redis list ``log:<name>``, and the whole
      list is deleted once its length exceeds ``TOTAL_LOGS``.

    A malformed payload is logged and skipped without dropping the
    subscription. Any other exception is logged, and after a one second
    pause a new connection and subscription are created. This coroutine
    never returns on its own.
    """
    channel_name = 'logs'
    logger.info(f"Subscribed to channel {channel_name}")
    while True:
        try:
            async for redis_client in get_redis_pool():
                async with redis_client.pubsub() as pubsub:
                    await pubsub.psubscribe(channel_name)
                    while True:
                        message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=3)
                        if message is None:
                            continue
                        # Parse defensively: one bad payload must not tear down
                        # the subscription, since messages published while
                        # reconnecting would be missed.
                        try:
                            data = json.loads(message["data"])
                            log_name = data["name"]
                            log_line = data["log"]
                            is_register = "Register Success" in log_line
                        except (ValueError, KeyError, TypeError) as e:
                            logger.error(f"Discarding malformed log message: {e}")
                            continue
                        logger.info("Received message:" + json.dumps(data))
                        if log_name in GET_LOG_NAME:
                            LOG_INFO.setdefault(log_name, []).append(log_line)
                        if is_register:
                            await check_node_task(log_name, redis_client)
                        log_key = f'log:{log_name}'
                        await redis_client.rpush(log_key, log_line)
                        total_logs = await redis_client.llen(log_key)
                        if total_logs > TOTAL_LOGS:
                            await redis_client.delete(log_key)
        except Exception as e:
            logger.error(f"An error occurred: {e}. Reconnecting...")
            await asyncio.sleep(1)  # wait a moment before trying to reconnect


async def check_node_task(node_name, redis_conn):
    """Re-queue every unfinished task that applies to ``node_name``.

    Looks up tasks whose ``progress`` is not 100 and that either name this
    node or have ``allNode`` set. For each one the target queue is checked
    via ``check_redis_task_target_is_null``; afterwards every task is pushed
    (as JSON produced from ``transform_db_redis``) onto the Redis list
    ``NodeTask:<node_name>``.

    :param node_name: name of the node to queue tasks for.
    :param redis_conn: Redis client used for all queue operations.
    """
    unfinished_filter = {
        "progress": {"$ne": 100},
        "$or": [
            {"node": node_name},
            {"allNode": True},
        ],
    }
    async for mongo_client in get_mongo_db():
        cursor: AsyncIOMotorCursor = mongo_client.task.find(unfinished_filter)
        tasks = await cursor.to_list(length=None)
        if not tasks:
            return

        # First pass: expose the id as a string and check every task's
        # target queue before anything is handed to the node.
        for task in tasks:
            task["id"] = str(task["_id"])
            await check_redis_task_target_is_null(task["id"], task["target"], redis_conn)

        # Second pass: push the tasks onto the node's queue.
        node_queue = f"NodeTask:{node_name}"
        for task in tasks:
            payload = transform_db_redis(task)
            await redis_conn.rpush(node_queue, json.dumps(payload))
        return


async def check_redis_task_target_is_null(id, target, redis_conn):
    """Rebuild the Redis target queue ``TaskInfo:<id>`` when it is missing.

    Nothing happens if the key ``TaskInfo:<id>`` already exists. Otherwise
    every newline-separated target whose progress hash
    ``TaskInfo:progress:<id>:<target>`` has no ``scan_end`` field is pushed
    onto ``TaskInfo:<id>``.

    When ``target`` is the empty string, the task document is loaded from
    MongoDB to obtain the targets, and the task is additionally pushed onto
    ``NodeTask:<name>`` for each of its nodes (all online nodes when the
    document's ``allNode`` is truthy).

    :param id: task id as a string (also used to build the ``ObjectId``).
    :param target: newline-separated targets, or ``""`` to load from MongoDB.
    :param redis_conn: Redis client used for all queue operations.
    """
    if await redis_conn.exists("TaskInfo:{}".format(id)):
        return

    from_check = False
    r = {}
    if target == "":
        from_check = True
        async for mongo_client in get_mongo_db():
            r = await mongo_client.task.find_one({"_id": ObjectId(id)})
        if r is None:
            # find_one returns None for an unknown id; there is nothing to queue.
            logger.error(f"Task {id} not found, cannot rebuild its target queue")
            return
        target = r.get("target", "")

    task_target = []
    for t in target.split("\n"):
        res = await redis_conn.hgetall(f"TaskInfo:progress:{id}:{t}")
        # Targets that already recorded scan_end are not queued again.
        if "scan_end" not in res:
            task_target.append(t)

    # LPUSH requires at least one value, so skip it when every target is done.
    if task_target:
        await redis_conn.lpush(f"TaskInfo:{id}", *task_target)

    if from_check:
        try:
            if r:
                if r['allNode']:
                    r["node"] = await get_redis_online_data(redis_conn)
                # NOTE(review): check_node_task sets doc["id"] before calling
                # transform_db_redis; this path does not - confirm whether
                # transform_db_redis needs it.
                add_redis_task_data = transform_db_redis(r)
                for name in r["node"]:
                    await redis_conn.rpush(f"NodeTask:{name}", json.dumps(add_redis_task_data))
        except Exception as e:
            logger.error(str(e))
    return


async def get_redis_online_data(redis_con):
    """Return the names of all nodes whose ``state`` hash field is ``'1'``.

    Node names are taken from the part of each ``node:*`` key after the
    first colon. State ``'1'`` is presumably the "online" state - confirm
    against the node implementation.

    The client is used directly rather than through ``async with``: leaving
    such a block closes the client, but it belongs to the caller, which
    keeps using it after this call.

    :param redis_con: Redis client owned by the caller.
    :return: list of matching node names.
    """
    online_nodes = []
    # Fetch every key that starts with "node:".
    for key in await redis_con.keys("node:*"):
        name = key.split(":")[1]
        hash_data = await redis_con.hgetall(key)
        if hash_data.get('state') == '1':
            online_nodes.append(name)
    return online_nodes
